Interaction clustering, PID and primary particles

Interaction clustering is done by another Graph Neural Network (GNN). Each node corresponds to a predicted particle. In addition to predicting which edges should be kept (i.e. the interaction clustering), the chain also predicts for each node a particle type (particle identification, PID) and a binary classification into primary/non-primary particles. We call primary particles the first particles to come out of an interaction vertex.

Imports and configuration

If needed, you can edit the path to the lartpc_mlreco3d library and to the data folder.

import os

# Location of the lartpc_mlreco3d checkout, expected directly under the user's
# home directory. expanduser('~') is used rather than os.environ.get('HOME') so
# that an unset HOME does not silently produce the bogus path
# 'None/lartpc_mlreco3d'.
SOFTWARE_DIR = os.path.join(os.path.expanduser('~'), 'lartpc_mlreco3d')
# Data folder, relative to the working directory; inference.cfg is read from here.
DATA_DIR = '../data'

The usual imports and setting up the Python module search path (`sys.path`)… click if you need to see them.

import sys, os
# Put SOFTWARE_DIR first on the module search path so that the `mlreco`
# imports below are looked up there before any other location.
sys.path.insert(0, SOFTWARE_DIR)
import numpy as np
import yaml
import torch
import plotly
import plotly.graph_objs as go
from plotly.offline import iplot, init_notebook_mode
# Enable inline plotly rendering in the notebook.
# NOTE(review): connected=False presumably loads plotly.js from the local
# package instead of a CDN (offline use) — confirm against the installed plotly.
init_notebook_mode(connected=False)

# Visualization, PPN, clustering and configuration helpers from lartpc_mlreco3d.
from mlreco.visualization import scatter_points, plotly_layout3d
from mlreco.visualization.gnn import scatter_clusters, network_topology, network_schematic
from mlreco.utils.ppn import uresnet_ppn_type_point_selector
from mlreco.utils.cluster.dense_cluster import fit_predict_np, gaussian_kernel
from mlreco.main_funcs import process_config, prepare
from mlreco.utils.gnn.cluster import get_cluster_label
from mlreco.utils.deghosting import adapt_labels_numpy as adapt_labels
# NOTE(review): network_topology is already imported from this module above;
# this line is redundant.
from mlreco.visualization.gnn import network_topology

from larcv import larcv
/usr/local/lib/python3.6/dist-packages/MinkowskiEngine/__init__.py:42: UserWarning:

The environment variable `OMP_NUM_THREADS` not set. MinkowskiEngine will automatically set `OMP_NUM_THREADS=16`. If you want to set `OMP_NUM_THREADS` manually, please export it on the command line before running a python script. e.g. `export OMP_NUM_THREADS=12; python your_program.py`. It is recommended to set it below 24.
Welcome to JupyROOT 6.22/02

The configuration is loaded from the file inference.cfg.

# Read the configuration text and replace every literal 'DATA_DIR' token in it
# with the actual data folder path. The context manager closes the file handle
# as soon as the text has been read (the original left it open).
with open('%s/inference.cfg' % DATA_DIR, 'r') as cfg_file:
    cfg_text = cfg_file.read().replace('DATA_DIR', DATA_DIR)
# NOTE(review): yaml.Loader can construct arbitrary Python objects. That is
# acceptable for a trusted local config file; prefer yaml.safe_load if the
# configuration could ever come from an untrusted source.
cfg = yaml.load(cfg_text, Loader=yaml.Loader)
# pre-process configuration (checks + certain non-specified default settings)
process_config(cfg)
# prepare function configures necessary "handlers"
hs = prepare(cfg)
Config processed at: Linux volt002 3.10.0-1160.21.1.el7.x86_64 #1 SMP Tue Mar 16 18:28:22 UTC 2021 x86_64 x86_64 x86_64 GNU/Linux

$CUDA_VISIBLE_DEVICES="0"

{   'iotool': {   'batch_size': 4,
                  'collate_fn': 'CollateSparse',
                  'dataset': {   'data_keys': [   '../data/wire_mpvmpr_2020_04_test_small.root'],
                                 'limit_num_files': 10,
                                 'name': 'LArCVDataset',
                                 'schema': {   'cluster_label': [   'parse_cluster3d_clean_full',
                                                                    'cluster3d_pcluster',
                                                                    'particle_pcluster',
                                                                    'particle_mpv',
                                                                    'sparse3d_pcluster_semantics'],
                                               'input_data': [   'parse_sparse3d_scn',
                                                                 'sparse3d_reco',
                                                                 'sparse3d_reco_chi2'],
                                               'kinematics_label': [   'parse_cluster3d_kinematics_clean',
                                                                       'cluster3d_pcluster',
                                                                       'particle_corrected',
                                                                       'particle_mpv',
                                                                       'sparse3d_pcluster_semantics'],
                                               'particle_graph': [   'parse_particle_graph_corrected',
                                                                     'particle_corrected',
                                                                     'cluster3d_pcluster'],
                                               'particles_asis': [   'parse_particle_asis',
                                                                     'particle_pcluster',
                                                                     'cluster3d_pcluster'],
                                               'particles_label': [   'parse_particle_points_with_tagging',
                                                                      'sparse3d_pcluster',
                                                                      'particle_corrected'],
                                               'segment_label': [   'parse_sparse3d_scn',
                                                                    'sparse3d_pcluster_semantics_ghost']}},
                  'minibatch_size': 4,
                  'num_workers': 4,
                  'shuffle': False},
    'model': {   'loss_input': [   'segment_label',
                                   'particles_label',
                                   'cluster_label',
                                   'kinematics_label',
                                   'particle_graph'],
                 'modules': {   'chain': {   'enable_cnn_clust': True,
                                             'enable_cosmic': True,
                                             'enable_ghost': True,
                                             'enable_gnn_inter': True,
                                             'enable_gnn_kinematics': True,
                                             'enable_gnn_particle': False,
                                             'enable_gnn_shower': True,
                                             'enable_gnn_track': True,
                                             'enable_ppn': True,
                                             'enable_uresnet': True,
                                             'use_ppn_in_gnn': True,
                                             'verbose': True},
                                'cosmic_discriminator': {   'network_base': {   'data_dim': 3,
                                                                                'features': 4,
                                                                                'leakiness': 0.33,
                                                                                'spatial_size': 768},
                                                            'res_encoder': {   'coordConv': True,
                                                                               'latent_size': 2,
                                                                               'pool_mode': 'avg'},
                                                            'uresnet_encoder': {   'features': 16,
                                                                                   'filters': 16,
                                                                                   'num_strides': 9},
                                                            'use_input_data': False,
                                                            'use_true_interactions': False},
                                'cosmic_loss': {   'node_loss': {   'balance_classes': True,
                                                                    'name': 'type',
                                                                    'target_col': 8}},
                                'dbscan_frag': {   'cluster_classes': [0, 2, 3],
                                                   'delta_label': 3,
                                                   'dim': 3,
                                                   'eps': [   1.999,
                                                              3.999,
                                                              1.999,
                                                              4.999],
                                                   'michel_label': 2,
                                                   'min_samples': 1,
                                                   'min_size': [3, 3, 3, 3],
                                                   'num_classes': 4,
                                                   'ppn_distance_threshold': 1.999,
                                                   'ppn_mask_radius': 5,
                                                   'ppn_score_threshold': 0.9,
                                                   'ppn_type_threshold': 0.3,
                                                   'track_clustering_method': 'closest_path',
                                                   'track_label': 1},
                                'full_chain_loss': {   'clustering_weight': 1.0,
                                                       'cosmic_weight': 1.0,
                                                       'flow_weight': 1.0,
                                                       'inter_gnn_weight': 1.0,
                                                       'kinematics_p_weight': 1.0,
                                                       'kinematics_type_weight': 1.0,
                                                       'kinematics_weight': 1.0,
                                                       'particle_gnn_weight': 1.0,
                                                       'ppn_weight': 1.0,
                                                       'segmentation_weight': 1.0,
                                                       'shower_gnn_weight': 1.0,
                                                       'track_gnn_weight': 1.0},
                                'grappa_inter': {   'base': {   'add_start_dir': True,
                                                                'add_start_point': True,
                                                                'group_pred': 'score',
                                                                'kinematics_mlp': True,
                                                                'kinematics_momentum': False,
                                                                'kinematics_type': True,
                                                                'node_min_size': 3,
                                                                'node_type': [   0,
                                                                                 1,
                                                                                 2,
                                                                                 3],
                                                                'start_dir_max_dist': 5,
                                                                'vertex_mlp': True},
                                                    'edge_encoder': {   'name': 'geo',
                                                                        'use_numpy': False},
                                                    'gnn_model': {   'aggr': 'add',
                                                                     'edge_classes': 2,
                                                                     'edge_feats': 19,
                                                                     'edge_output_feats': 64,
                                                                     'leakiness': 0.1,
                                                                     'name': 'meta',
                                                                     'node_classes': 2,
                                                                     'node_feats': 28,
                                                                     'node_output_feats': 64,
                                                                     'num_mp': 3},
                                                    'node_encoder': {   'name': 'geo',
                                                                        'use_numpy': False},
                                                    'type_net': {   'num_hidden': 32},
                                                    'vertex_net': {   'num_hidden': 32}},
                                'grappa_inter_loss': {   'edge_loss': {   'balance_classes': False,
                                                                          'high_purity': False,
                                                                          'loss': 'CE',
                                                                          'name': 'channel',
                                                                          'reduction': 'sum',
                                                                          'source_col': 6,
                                                                          'target': 'group',
                                                                          'target_col': 7},
                                                         'node_loss': {   'balance_classes': True,
                                                                          'name': 'kinematics',
                                                                          'spatial_size': 768,
                                                                          'type_loss': 'CE'}},
                                'grappa_kinematics': {   'base': {   'edge_dist_metric': 'set',
                                                                     'edge_dist_numpy': True,
                                                                     'edge_max_dist': -1,
                                                                     'kinematics_mlp': True,
                                                                     'kinematics_momentum': True,
                                                                     'kinematics_type': False,
                                                                     'network': 'complete',
                                                                     'node_min_size': -1,
                                                                     'node_type': -1},
                                                         'edge_encoder': {   'cnn_encoder': {   'name': 'cnn2',
                                                                                                'network_base': {   'data_dim': 3,
                                                                                                                    'features': 4,
                                                                                                                    'leakiness': 0.33,
                                                                                                                    'spatial_size': 768},
                                                                                                'res_encoder': {   'coordConv': True,
                                                                                                                   'latent_size': 32,
                                                                                                                   'pool_mode': 'avg'},
                                                                                                'uresnet_encoder': {   'filters': 32,
                                                                                                                       'input_kernel': 3,
                                                                                                                       'num_classes': 5,
                                                                                                                       'num_filters': 32,
                                                                                                                       'num_strides': 9,
                                                                                                                       'reps': 2}},
                                                                             'geo_encoder': {   'more_feats': True},
                                                                             'name': 'mix_debug',
                                                                             'normalize': True},
                                                         'gnn_model': {   'aggr': 'add',
                                                                          'edge_classes': 2,
                                                                          'edge_feats': 51,
                                                                          'edge_output_feats': 64,
                                                                          'leak': 0.33,
                                                                          'name': 'nnconv_old',
                                                                          'node_classes': 5,
                                                                          'node_feats': 83,
                                                                          'node_output_feats': 128,
                                                                          'num_mp': 3},
                                                         'momentum_net': {   'num_hidden': 32},
                                                         'node_encoder': {   'cnn_encoder': {   'name': 'cnn2',
                                                                                                'network_base': {   'data_dim': 3,
                                                                                                                    'features': 4,
                                                                                                                    'leakiness': 0.33,
                                                                                                                    'spatial_size': 768},
                                                                                                'res_encoder': {   'coordConv': True,
                                                                                                                   'latent_size': 64,
                                                                                                                   'pool_mode': 'avg'},
                                                                                                'uresnet_encoder': {   'filters': 32,
                                                                                                                       'input_kernel': 3,
                                                                                                                       'num_classes': 5,
                                                                                                                       'num_filters': 16,
                                                                                                                       'num_strides': 9,
                                                                                                                       'reps': 2}},
                                                                             'geo_encoder': {   'more_feats': True},
                                                                             'name': 'mix_debug',
                                                                             'normalize': True},
                                                         'use_true_particles': False},
                                'grappa_kinematics_loss': {   'edge_loss': {   'balance_classes': False,
                                                                               'high_purity': False,
                                                                               'name': 'channel',
                                                                               'reduction': 'sum',
                                                                               'target': 'particle_forest'},
                                                              'node_loss': {   'name': 'kinematics',
                                                                               'reg_loss': 'l2'}},
                                'grappa_particle': {   'base': {   'node_min_size': 10,
                                                                   'node_type': -1},
                                                       'edge_encoder': {   'name': 'geo',
                                                                           'use_numpy': True},
                                                       'gnn_model': {   'aggr': 'add',
                                                                        'edge_classes': 2,
                                                                        'edge_feats': 19,
                                                                        'edge_output_feats': 64,
                                                                        'leakiness': 0.1,
                                                                        'name': 'meta',
                                                                        'node_classes': 2,
                                                                        'node_feats': 24,
                                                                        'node_output_feats': 64,
                                                                        'num_mp': 3},
                                                       'node_encoder': {   'name': 'geo',
                                                                           'use_numpy': True}},
                                'grappa_particle_loss': {   'edge_loss': {   'balance_classes': False,
                                                                             'high_purity': True,
                                                                             'loss': 'CE',
                                                                             'name': 'channel',
                                                                             'reduction': 'sum',
                                                                             'source_col': 5,
                                                                             'target': 'group',
                                                                             'target_col': 6},
                                                            'node_loss': {   'balance_classes': False,
                                                                             'group_pred_alg': 'score',
                                                                             'high_purity': True,
                                                                             'loss': 'CE',
                                                                             'name': 'primary',
                                                                             'reduction': 'sum',
                                                                             'use_group_pred': True}},
                                'grappa_shower': {   'base': {   'add_start_dir': True,
                                                                 'add_start_point': True,
                                                                 'node_min_size': 3,
                                                                 'node_type': 0,
                                                                 'start_dir_max_dist': 5},
                                                     'edge_encoder': {   'name': 'geo',
                                                                         'use_numpy': False},
                                                     'gnn_model': {   'aggr': 'add',
                                                                      'edge_classes': 2,
                                                                      'edge_feats': 19,
                                                                      'edge_output_feats': 64,
                                                                      'leakiness': 0.1,
                                                                      'name': 'meta',
                                                                      'node_classes': 2,
                                                                      'node_feats': 28,
                                                                      'node_output_feats': 64,
                                                                      'num_mp': 3},
                                                     'node_encoder': {   'name': 'geo',
                                                                         'use_numpy': False}},
                                'grappa_shower_loss': {   'edge_loss': {   'high_purity': True,
                                                                           'name': 'channel',
                                                                           'source_col': 5,
                                                                           'target_col': 6},
                                                          'node_loss': {   'group_pred_alg': 'score',
                                                                           'high_purity': True,
                                                                           'name': 'primary',
                                                                           'use_group_pred': True}},
                                'grappa_track': {   'base': {   'add_start_dir': True,
                                                                'add_start_point': True,
                                                                'node_min_size': 3,
                                                                'node_type': 1,
                                                                'start_dir_max_dist': 5},
                                                    'edge_encoder': {   'name': 'geo',
                                                                        'use_numpy': False},
                                                    'gnn_model': {   'aggr': 'add',
                                                                     'edge_classes': 2,
                                                                     'edge_feats': 19,
                                                                     'edge_output_feats': 64,
                                                                     'leakiness': 0.1,
                                                                     'name': 'meta',
                                                                     'node_classes': 2,
                                                                     'node_feats': 28,
                                                                     'node_output_feats': 64,
                                                                     'num_mp': 3},
                                                    'node_encoder': {   'name': 'geo',
                                                                        'use_numpy': False}},
                                'grappa_track_loss': {   'edge_loss': {   'high_purity': False,
                                                                          'name': 'channel',
                                                                          'source_col': 5,
                                                                          'target_col': 6}},
                                'spice': {   'fragment_clustering': {   'cluster_all': False,
                                                                        'cluster_classes': [   1],
                                                                        'min_frag_size': 10,
                                                                        'min_voxels': 2,
                                                                        'p_thresholds': [   0.95,
                                                                                            0.95,
                                                                                            0.95,
                                                                                            0.95],
                                                                        's_thresholds': [   0.0,
                                                                                            0.0,
                                                                                            0.0,
                                                                                            0.35]},
                                             'network_base': {   'data_dim': 3,
                                                                 'features': 4,
                                                                 'leakiness': 0.33,
                                                                 'spatial_size': 768},
                                             'spatial_embeddings': {   'coordConv': True,
                                                                       'embedding_dim': 3,
                                                                       'seediness_dim': 1,
                                                                       'sigma_dim': 1},
                                             'uresnet': {   'filters': 64,
                                                            'input_kernel_size': 7,
                                                            'num_strides': 7,
                                                            'reps': 2}},
                                'spice_loss': {   'embedding_weight': 1.0,
                                                  'mask_loss_fn': 'lovasz_hinge',
                                                  'min_voxels': 2,
                                                  'name': 'se_vectorized_inter',
                                                  'seediness_weight': 1.0,
                                                  'smoothing_weight': 1.0},
                                'uresnet_ppn': {   'ppn': {   'classify_endpoints': True,
                                                              'data_dim': 3,
                                                              'downsample_ghost': True,
                                                              'filters': 16,
                                                              'model_name': 'ppn',
                                                              'model_path': '../data/weights_ppn3_snapshot-1999.ckpt',
                                                              'num_classes': 5,
                                                              'num_strides': 6,
                                                              'ppn1_size': 24,
                                                              'ppn2_size': 96,
                                                              'ppn_num_conv': 1,
                                                              'score_threshold': 0.5,
                                                              'spatial_size': 768,
                                                              'use_encoding': False,
                                                              'weight_ppn': 0.9},
                                                   'uresnet_lonely': {   'data_dim': 3,
                                                                         'features': 2,
                                                                         'filters': 16,
                                                                         'freeze': False,
                                                                         'ghost': True,
                                                                         'leakiness': 0.0,
                                                                         'num_classes': 5,
                                                                         'num_strides': 6,
                                                                         'spatial_size': 768,
                                                                         'weight_loss': True}}},
                 'name': 'full_chain',
                 'network_input': ['input_data']},
    'trainval': {   'checkpoint_step': 100,
                    'concat_result': [   'seediness',
                                         'margins',
                                         'embeddings',
                                         'fragments',
                                         'fragments_seg',
                                         'shower_fragments',
                                         'shower_edge_index',
                                         'shower_edge_pred',
                                         'shower_node_pred',
                                         'shower_group_pred',
                                         'track_fragments',
                                         'track_edge_index',
                                         'track_node_pred',
                                         'track_edge_pred',
                                         'track_group_pred',
                                         'particle_fragments',
                                         'particle_edge_index',
                                         'particle_node_pred',
                                         'particle_edge_pred',
                                         'particle_group_pred',
                                         'particles',
                                         'inter_edge_index',
                                         'inter_node_pred',
                                         'inter_edge_pred',
                                         'node_pred_p',
                                         'node_pred_type',
                                         'flow_edge_pred',
                                         'kinematics_particles',
                                         'kinematics_edge_index',
                                         'clust_fragments',
                                         'clust_frag_seg',
                                         'interactions',
                                         'inter_cosmic_pred',
                                         'node_pred_vtx',
                                         'total_num_points',
                                         'total_nonghost_points'],
                    'debug': False,
                    'gpus': [0],
                    'iterations': 652,
                    'log_dir': './log_trash',
                    'minibatch_size': -1,
                    'model_path': '../data/weights_full5_snapshot-999.cpkt',
                    'optimizer': {'args': {'lr': 0.001}, 'name': 'Adam'},
                    'report_step': 1,
                    'seed': 123,
                    'train': False,
                    'unwrapper': 'unwrap_3d_scn',
                    'weight_prefix': './weights_trash/snapshot'}}
Loading file: ../data/wire_mpvmpr_2020_04_test_small.root
Loading tree sparse3d_reco
Loading tree sparse3d_reco_chi2
Loading tree sparse3d_pcluster_semantics_ghost
Loading tree cluster3d_pcluster
Loading tree particle_pcluster
Loading tree particle_mpv
Loading tree sparse3d_pcluster_semantics
Loading tree sparse3d_pcluster
Loading tree particle_corrected
Sequential(
  (0): Sequential(
    (0): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
    (1): NetworkInNetwork64->4
  )
  (1): OutputLayer()
)
ClusterCNN(
  (input): Sequential(
    (0): InputLayer()
    (1): SubmanifoldConvolution 4->64 C7
  )
  (concat): JoinTable()
  (add): AddTable()
  (encoding_block): Sequential(
    (0): Sequential(
      (0): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 64->64 C3
          (2): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 64->64 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 64->64 C3
          (2): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 64->64 C3
        )
      )
      (3): AddTable()
    )
    (1): Sequential(
      (0): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 128->128 C3
          (2): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 128->128 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 128->128 C3
          (2): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 128->128 C3
        )
      )
      (3): AddTable()
    )
    (2): Sequential(
      (0): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 192->192 C3
          (2): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 192->192 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 192->192 C3
          (2): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 192->192 C3
        )
      )
      (3): AddTable()
    )
    (3): Sequential(
      (0): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 256->256 C3
          (2): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 256->256 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 256->256 C3
          (2): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 256->256 C3
        )
      )
      (3): AddTable()
    )
    (4): Sequential(
      (0): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 320->320 C3
          (2): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 320->320 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 320->320 C3
          (2): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 320->320 C3
        )
      )
      (3): AddTable()
    )
    (5): Sequential(
      (0): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 384->384 C3
          (2): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 384->384 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 384->384 C3
          (2): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 384->384 C3
        )
      )
      (3): AddTable()
    )
    (6): Sequential(
      (0): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(448,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 448->448 C3
          (2): BatchNormLeakyReLU(448,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 448->448 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(448,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 448->448 C3
          (2): BatchNormLeakyReLU(448,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 448->448 C3
        )
      )
      (3): AddTable()
    )
  )
  (encoding_conv): Sequential(
    (0): Sequential(
      (0): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Convolution 64->128 C2/2
    )
    (1): Sequential(
      (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Convolution 128->192 C2/2
    )
    (2): Sequential(
      (0): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Convolution 192->256 C2/2
    )
    (3): Sequential(
      (0): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Convolution 256->320 C2/2
    )
    (4): Sequential(
      (0): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Convolution 320->384 C2/2
    )
    (5): Sequential(
      (0): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Convolution 384->448 C2/2
    )
    (6): Sequential()
  )
  (decoding_block): Sequential(
    (0): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork768->384
        (1): Sequential(
          (0): BatchNormLeakyReLU(768,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 768->384 C3
          (2): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 384->384 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 384->384 C3
          (2): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 384->384 C3
        )
      )
      (3): AddTable()
    )
    (1): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork640->320
        (1): Sequential(
          (0): BatchNormLeakyReLU(640,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 640->320 C3
          (2): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 320->320 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 320->320 C3
          (2): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 320->320 C3
        )
      )
      (3): AddTable()
    )
    (2): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork512->256
        (1): Sequential(
          (0): BatchNormLeakyReLU(512,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 512->256 C3
          (2): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 256->256 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 256->256 C3
          (2): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 256->256 C3
        )
      )
      (3): AddTable()
    )
    (3): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork384->192
        (1): Sequential(
          (0): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 384->192 C3
          (2): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 192->192 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 192->192 C3
          (2): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 192->192 C3
        )
      )
      (3): AddTable()
    )
    (4): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork256->128
        (1): Sequential(
          (0): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 256->128 C3
          (2): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 128->128 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 128->128 C3
          (2): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 128->128 C3
        )
      )
      (3): AddTable()
    )
    (5): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork128->64
        (1): Sequential(
          (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 128->64 C3
          (2): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 64->64 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 64->64 C3
          (2): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 64->64 C3
        )
      )
      (3): AddTable()
    )
  )
  (decoding_conv): Sequential(
    (0): Sequential(
      (0): BatchNormLeakyReLU(448,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 448->384 C2/2
    )
    (1): Sequential(
      (0): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 384->320 C2/2
    )
    (2): Sequential(
      (0): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 320->256 C2/2
    )
    (3): Sequential(
      (0): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 256->192 C2/2
    )
    (4): Sequential(
      (0): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 192->128 C2/2
    )
    (5): Sequential(
      (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 128->64 C2/2
    )
  )
  (decoding_block2): Sequential(
    (0): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork768->384
        (1): Sequential(
          (0): BatchNormLeakyReLU(768,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 768->384 C3
          (2): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 384->384 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 384->384 C3
          (2): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 384->384 C3
        )
      )
      (3): AddTable()
    )
    (1): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork640->320
        (1): Sequential(
          (0): BatchNormLeakyReLU(640,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 640->320 C3
          (2): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 320->320 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 320->320 C3
          (2): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 320->320 C3
        )
      )
      (3): AddTable()
    )
    (2): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork512->256
        (1): Sequential(
          (0): BatchNormLeakyReLU(512,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 512->256 C3
          (2): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 256->256 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 256->256 C3
          (2): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 256->256 C3
        )
      )
      (3): AddTable()
    )
    (3): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork384->192
        (1): Sequential(
          (0): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 384->192 C3
          (2): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 192->192 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 192->192 C3
          (2): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 192->192 C3
        )
      )
      (3): AddTable()
    )
    (4): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork256->128
        (1): Sequential(
          (0): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 256->128 C3
          (2): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 128->128 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 128->128 C3
          (2): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 128->128 C3
        )
      )
      (3): AddTable()
    )
    (5): Sequential(
      (0): ConcatTable(
        (0): NetworkInNetwork128->64
        (1): Sequential(
          (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 128->64 C3
          (2): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 64->64 C3
        )
      )
      (1): AddTable()
      (2): ConcatTable(
        (0): Identity()
        (1): Sequential(
          (0): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (1): SubmanifoldConvolution 64->64 C3
          (2): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
          (3): SubmanifoldConvolution 64->64 C3
        )
      )
      (3): AddTable()
    )
  )
  (decoding_conv2): Sequential(
    (0): Sequential(
      (0): BatchNormLeakyReLU(448,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 448->384 C2/2
    )
    (1): Sequential(
      (0): BatchNormLeakyReLU(384,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 384->320 C2/2
    )
    (2): Sequential(
      (0): BatchNormLeakyReLU(320,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 320->256 C2/2
    )
    (3): Sequential(
      (0): BatchNormLeakyReLU(256,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 256->192 C2/2
    )
    (4): Sequential(
      (0): BatchNormLeakyReLU(192,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 192->128 C2/2
    )
    (5): Sequential(
      (0): BatchNormLeakyReLU(128,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): Deconvolution 128->64 C2/2
    )
  )
  (outputEmbeddings): Sequential(
    (0): Sequential(
      (0): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): NetworkInNetwork64->4
    )
    (1): OutputLayer()
  )
  (outputSeediness): Sequential(
    (0): Sequential(
      (0): BatchNormLeakyReLU(64,eps=0.0001,momentum=0.99,affine=True,leakiness=0.33)
      (1): NetworkInNetwork64->1
    )
    (1): OutputLayer()
  )
  (tanh): Tanh()
  (sigmoid): Sigmoid()
)
Total Number of Trainable Parameters = 175210432
Restoring weights for  from ../data/weights_full5_snapshot-999.cpkt...
INCOMPATIBLE KEYS!
['module.uresnet_lonely.bn.weight', 'module.uresnet_lonely.bn.bias', 'module.uresnet_lonely.bn.running_mean', 'module.uresnet_lonely.bn.running_var']
make sure your module is named  
Done.
Restoring weights for ppn from ../data/weights_ppn3_snapshot-1999.ckpt...
Done.

The output is hidden because it reprints the entire (lengthy) configuration. Feel free to take a look if you are curious!

Finally we run the chain for 1 iteration:

# Run the chain forward once. The call returns two objects: "data" (indexed
# below with keys such as 'input_data' and 'cluster_label') and "output"
# (indexed below with keys such as 'ghost', 'segmentation' and 'particles').
data, output = hs.trainer.forward(hs.data_io_iter)
Segmentation Accuracy: 0.9882
PPN Accuracy: 0.8556
Clustering Accuracy: 0.9465
Shower fragment clustering accuracy: 1.0000
Shower primary prediction accuracy: 0.0000
Track fragment clustering accuracy: 0.9907
Particle ID accuracy: -1.2915
Interaction grouping accuracy: 0.9609
Flow accuracy: 0.9841
Type accuracy: 0.7500
Momentum accuracy: -3.7679
Vertex position accuracy: 0.6862
Vertex score accuracy: 0.9111
Cosmic discrimination accuracy: 0.8696

Now we can play with data and output to visualize what we are interested in. Feel free to change the entry index if you want to look at a different entry!

# Index of the entry to look at; every item of `data` and `output` used below
# is indexed with it.
entry = 0

Let us grab the interesting quantities:

# Labels and input voxels of the selected entry.
clust_label = data['cluster_label'][entry]
input_data = data['input_data'][entry]
# Only the last column of the segmentation label array is kept
# (presumably the semantic class of each voxel -- TODO confirm).
segment_label = data['segment_label'][entry][:, -1]

# Boolean mask that is True where the highest ghost score is class 0. It is
# used below to select the input voxels handed to network_topology, so class 0
# is presumably "non-ghost" -- TODO confirm.
ghost_mask = output['ghost'][entry].argmax(axis=1) == 0
# Predicted class per row: index of the highest segmentation score.
segment_pred = output['segmentation'][entry].argmax(axis=1)

Visualization of interaction clustering

Because our small dataset has ghost points, we need to adapt the true cluster labels (which do not label ghost points by default). Each true ghost point that is predicted as a non-ghost point is assigned the label of the closest true non-ghost point. True ghost points which are correctly predicted as ghost points keep a label of -1 for everything.

# Adapt the true cluster labels to the predicted ghost/non-ghost split (see the
# explanation above), then keep only the selected entry.
clust_label_adapted = adapt_labels(output, data['segment_label'], data['cluster_label'])[entry]

# One true label per predicted particle, read from column 7 of the adapted
# labels (used in this notebook as the interaction id -- TODO confirm the
# column layout against the data loader).
clust_ids_true = get_cluster_label(torch.tensor(clust_label_adapted), output['particles'][entry], column=7)
# Predicted group of each particle, as returned by the chain.
clust_ids_pred = output['inter_group_pred'][entry]

Note that the function get_cluster_label uses the majority rule to determine the true label of a cluster of voxels (here, particles).

# Arguments shared by the two network_topology calls of this figure.
palette = plotly.colors.qualitative.Dark24
selected_voxels = data['input_data'][entry][ghost_mask]
particle_clusters = output['particles'][entry]
topology_style = dict(markersize=2, cmin=0, cmax=10, colorscale=palette)

trace = []

# Predicted particles colored by their true interaction label.
trace.extend(network_topology(selected_voxels,
                              particle_clusters,
                              clust_labels=clust_ids_true,
                              **topology_style))
trace[-1].name = 'True interactions'

# Voxel-level view of the adapted labels, colored by column 7.
trace.extend(scatter_points(clust_label_adapted,
                            markersize=1,
                            color=clust_label_adapted[:, 7],
                            colorscale=palette))
trace[-1].name = 'Adapted cluster labels'

# Predicted particles colored by their predicted interaction group.
trace.extend(network_topology(selected_voxels,
                              particle_clusters,
                              clust_labels=clust_ids_pred,
                              **topology_style))
trace[-1].name = 'Predicted interactions'

fig = go.Figure(data=trace, layout=plotly_layout3d())
# Shift the legend to the right of the 3D scene.
fig.update_layout(legend={'x': 1.1, 'y': 0.9})

iplot(fig)

Primary particles predictions

We need to get the true labels first:

kinematics_label = data['kinematics_label'][entry]
# Unique rows of columns 9-11 (treated here as the true vertex coordinates --
# TODO confirm the column layout). With return_index=True, `inv` holds the index
# of the first occurrence of each unique row; despite its name it is NOT the
# inverse mapping that return_inverse=True would give.
true_vtx, inv = np.unique(kinematics_label[:, 9:12], axis=0, return_index=True)
# Column 12 (used in this notebook as the primary flag) at those first occurrences:
# one value per unique vertex.
true_vtx_primary = kinematics_label[inv, 12]

And the predictions:

# Columns 3 and onward of node_pred_vtx hold the primary/non-primary scores;
# the argmax gives one class index per node. The first three columns are
# presumably the predicted vertex coordinates -- TODO confirm.
vtx_primary_pred = output['node_pred_vtx'][entry][:, 3:].argmax(axis=1)

We take the argmax of the softmax scores: the prediction is 0 for a non-primary particle and 1 for a primary particle.

palette = plotly.colors.qualitative.Dark24

trace = []

# True labels at voxel level, colored by column 12 of the kinematics labels.
trace.extend(scatter_points(kinematics_label,
                            markersize=1,
                            color=kinematics_label[:, 12],
                            colorscale=palette))
trace[-1].name = 'True vertex primary particles'

# Predicted particles colored by their predicted primary class (0 or 1).
selected_voxels = data['input_data'][entry][ghost_mask]
particle_clusters = output['particles'][entry]
trace.extend(network_topology(selected_voxels,
                              particle_clusters,
                              clust_labels=vtx_primary_pred,
                              markersize=2,
                              cmin=0,
                              cmax=10,
                              colorscale=palette))
trace[-1].name = 'Predicted vertex primary particles'

fig = go.Figure(data=trace, layout=plotly_layout3d())
# Shift the legend to the right of the 3D scene.
fig.update_layout(legend={'x': 1.1, 'y': 0.9})

iplot(fig)

Particle identification (PID)

The predictions are in node_pred_type:

# Index of the highest type score for each node, i.e. one predicted particle
# type (integer code, see the table below) per predicted particle.
type_pred = output['node_pred_type'][entry].argmax(axis=1)

Here is the meaning of each integer type:

| Integer | Particle type |
|---------|---------------|
| 0 | Photon (\(\gamma\)) |
| 1 | Electron (\(e\)) |
| 2 | Muon (\(\mu\)) |
| 3 | Pion (\(\pi\)) |
| 4 | Proton (\(p\)) |

trace = []

# True particle type at voxel level: the second-to-last column of the
# kinematics labels is used as the color of each true voxel.
trace += scatter_points(clust_label, markersize=1, color=kinematics_label[:, -2],
                        colorscale=plotly.colors.qualitative.Dark24)
trace[-1].name = 'True particle type'

# type_pred holds one value per predicted particle (node), not one per voxel,
# so it cannot be used as a per-voxel color for scatter_points. It is drawn
# with network_topology instead, which takes per-particle labels together with
# the voxel indices of each particle.
trace += network_topology(data['input_data'][entry][ghost_mask],
                          output['particles'][entry],
                          clust_labels=type_pred,
                          markersize=2, cmin=0, cmax=10, colorscale=plotly.colors.qualitative.Dark24)
trace[-1].name = 'Predicted particle type'

fig = go.Figure(data=trace, layout=plotly_layout3d())
# Shift the legend to the right of the 3D scene.
fig.update_layout(legend=dict(x=1.1, y=0.9))

iplot(fig)